{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1e007213",
   "metadata": {},
   "source": [
    "# Using a Custom DataBuilder in SecretFlow (Torch)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "389a89ef",
   "metadata": {},
   "source": [
    "The following code is for demonstration only. It is **NOT for production** use due to system security concerns; please **DO NOT** use it directly in production."
   ]
  },
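  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8528a86e",
   "metadata": {},
   "source": [
    "This tutorial shows how to load data with a custom DataBuilder and train a model in SecretFlow's multi-party secure environment.\n",
    "It uses an image classification task on the Flower dataset to illustrate how to run federated learning with a custom DataBuilder."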
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b03e55b2",
   "metadata": {},
   "source": [
    "## Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c08ecd6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "812c6aea",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-04-17 15:14:51,955\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check your SecretFlow version\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# Shut down any SecretFlow runtime that is already running.\n",
    "sf.shutdown()\n",
    "sf.init(['alice', 'bob', 'charlie'], address=\"local\", log_to_driver=False)\n",
    "alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "25c38ba4",
   "metadata": {},
   "source": [
    "## API Overview"
   ]
  },
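  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "509e5461",
   "metadata": {},
   "source": [
    "`FLModel` in SecretFlow supports reading data through a custom DataBuilder, which lets users handle data input more flexibly according to their needs.\n",
    "The example below shows how to train a federated model with a custom DataBuilder."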
   ]
  },
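  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ebfcf77f",
   "metadata": {},
   "source": [
    "Steps for using a DataBuilder:\n",
    "1. Develop with the single-machine PyTorch engine and write a DataBuilder function that builds a `DataLoader` in PyTorch. *Note: the `dataset_builder` function must accept a `stage` parameter.*\n",
    "2. Wrap each party's DataBuilder function to obtain `create_dataset_builder`.\n",
    "3. Construct `data_builder_dict`, which maps each PYU to its `dataset_builder`.\n",
    "4. Pass `data_builder_dict` as the `dataset_builder` argument of `fit`. The `x` argument then carries whatever input the `dataset_builder` needs (in this example, the path to the image folder actually used).\n"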
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fad75fd6",
   "metadata": {},
   "source": [
    "Using a DataBuilder in FLModel requires defining a databuilder dict in advance. Each builder must return a Torch `DataLoader` and `steps_per_epoch`, and the `steps_per_epoch` returned by all parties must be identical.\n",
    "```python\n",
    "data_builder_dict = {\n",
    "    alice: create_alice_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),  # the builder for alice must return (DataLoader, steps_per_epoch)\n",
    "    bob: create_bob_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),  # the builder for bob must return (DataLoader, steps_per_epoch)\n",
    "}\n",
    "```"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c55b0faf",
   "metadata": {},
   "source": [
    "## Download the Data"
   ]
  },
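  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "07ffc09e",
   "metadata": {},
   "source": [
    "About the Flower dataset: it contains 4323 color images of 5 kinds of flowers (daisy, dandelion, rose, sunflower, tulip). Each kind is photographed from multiple angles and under different lighting, and each image has a resolution of 320x240. The dataset is commonly used to train and test image classification and machine learning algorithms. The number of images per class is: daisy (633), dandelion (898), rose (641), sunflower (699), tulip (852).  \n",
    "  \n",
    "Download link: [http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz)\n",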
    "<img alt=\"flower_dataset_demo.png\" src=\"resources/flower_dataset_demo.png\" width=\"600\">  \n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ef1f5d14",
   "metadata": {},
   "source": [
    "### Download and Extract the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ff9720cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz\n",
      "67588319/67588319 [==============================] - 1s 0us/step\n"
     ]
    }
   ],
   "source": [
    "# Reuse the TensorFlow utility to download the images; the output is a folder, as shown in the figure above\n",
    "import tempfile\n",
    "import tensorflow as tf\n",
    "\n",
    "_temp_dir = tempfile.mkdtemp()\n",
    "path_to_flower_dataset = tf.keras.utils.get_file(\n",
    "    \"flower_photos\",\n",
    "    \"https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz\",\n",
    "    untar=True,\n",
    "    cache_dir=_temp_dir,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "41f10d04",
   "metadata": {},
   "source": [
    "# Next, We Build the Custom DataBuilder"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "cba8e9cd",
   "metadata": {},
   "source": [
    "## 1. Develop the DataBuilder with the Single-Machine Engine"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b440fcfe",
   "metadata": {},
   "source": [
    "When developing a `DataBuilder`, we can simply follow ordinary single-machine development logic.  \n",
    "The goal is just to build a Torch `DataLoader` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4160dfd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# Parameters\n",
    "batch_size = 32\n",
    "shuffle = True\n",
    "random_seed = 1234\n",
    "train_split = 0.8\n",
    "\n",
    "# Define the dataset\n",
    "flower_transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize((180, 180)),\n",
    "        transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "flower_dataset = datasets.ImageFolder(path_to_flower_dataset, transform=flower_transform)\n",
    "dataset_size = len(flower_dataset)\n",
    "\n",
    "# Define the samplers for the train/validation split\n",
    "indices = list(range(dataset_size))\n",
    "if shuffle:\n",
    "    np.random.seed(random_seed)\n",
    "    np.random.shuffle(indices)\n",
    "split = int(np.floor(train_split * dataset_size))\n",
    "train_indices, val_indices = indices[:split], indices[split:]\n",
    "train_sampler = SubsetRandomSampler(train_indices)\n",
    "valid_sampler = SubsetRandomSampler(val_indices)\n",
    "\n",
    "# Define the data loaders\n",
    "train_loader = DataLoader(\n",
    "    flower_dataset, batch_size=batch_size, sampler=train_sampler\n",
    ")\n",
    "valid_loader = DataLoader(\n",
    "    flower_dataset, batch_size=batch_size, sampler=valid_sampler\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "85582af2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.shape = torch.Size([32, 3, 180, 180])\n",
      "y.shape = torch.Size([32])\n"
     ]
    }
   ],
   "source": [
    "x, y = next(iter(train_loader))\n",
    "print(f\"x.shape = {x.shape}\")\n",
    "print(f\"y.shape = {y.shape}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ae5d314a",
   "metadata": {},
   "source": [
    "## 2. Wrap the Finished DataBuilder"
   ]
  },
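  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bcbd4c8e",
   "metadata": {},
   "source": [
    "The DataBuilder we developed has to be distributed to each executing machine at run time, so we wrap it to make it serializable.  \n",
    "Note that:\n",
    "- FLModel requires the DataBuilder's inputs to include a `stage` parameter (e.g. `stage=\"train\"`).\n",
    "- FLModel requires the DataBuilder to return two results: (data_set, steps_per_epoch)."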
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c051b566",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset_builder(\n",
    "    batch_size=32,\n",
    "    train_split=0.8,\n",
    "    shuffle=True,\n",
    "    random_seed=1234,\n",
    "):\n",
    "    def dataset_builder(x, stage=\"train\"):\n",
    "        \"\"\"Build a DataLoader from the image folder `x`.\n",
    "\n",
    "        Returns (DataLoader, steps_per_epoch) for stage \"train\" or \"eval\".\n",
    "        \"\"\"\n",
    "        import math\n",
    "\n",
    "        import numpy as np\n",
    "        from torch.utils.data import DataLoader\n",
    "        from torch.utils.data.sampler import SubsetRandomSampler\n",
    "        from torchvision import datasets, transforms\n",
    "\n",
    "        # Define the dataset\n",
    "        flower_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.Resize((180, 180)),\n",
    "                transforms.ToTensor(),\n",
    "            ]\n",
    "        )\n",
    "        flower_dataset = datasets.ImageFolder(x, transform=flower_transform)\n",
    "        dataset_size = len(flower_dataset)\n",
    "\n",
    "        # Define the samplers for the train/validation split\n",
    "        indices = list(range(dataset_size))\n",
    "        if shuffle:\n",
    "            np.random.seed(random_seed)\n",
    "            np.random.shuffle(indices)\n",
    "        split = int(np.floor(train_split * dataset_size))\n",
    "        train_indices, val_indices = indices[:split], indices[split:]\n",
    "        train_sampler = SubsetRandomSampler(train_indices)\n",
    "        valid_sampler = SubsetRandomSampler(val_indices)\n",
    "\n",
    "        # Define the data loaders\n",
    "        train_loader = DataLoader(\n",
    "            flower_dataset, batch_size=batch_size, sampler=train_sampler\n",
    "        )\n",
    "        valid_loader = DataLoader(\n",
    "            flower_dataset, batch_size=batch_size, sampler=valid_sampler\n",
    "        )\n",
    "\n",
    "        # Return the loader and step count for the requested stage\n",
    "        if stage == \"train\":\n",
    "            train_step_per_epoch = math.ceil(split / batch_size)\n",
    "            return train_loader, train_step_per_epoch\n",
    "        elif stage == \"eval\":\n",
    "            eval_step_per_epoch = math.ceil((dataset_size - split) / batch_size)\n",
    "            return valid_loader, eval_step_per_epoch\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported stage: {stage!r}\")\n",
    "\n",
    "    return dataset_builder"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0ffc7776",
   "metadata": {},
   "source": [
    "## 3. Build data_builder_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a9659cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the dataset builder dict\n",
    "data_builder_dict = {\n",
    "    alice: create_dataset_builder(\n",
    "        batch_size=32,\n",
    "        train_split=0.8,\n",
    "        shuffle=False,\n",
    "        random_seed=1234,\n",
    "    ),\n",
    "    bob: create_dataset_builder(\n",
    "        batch_size=32,\n",
    "        train_split=0.8,\n",
    "        shuffle=False,\n",
    "        random_seed=1234,\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "77dd0f46",
   "metadata": {},
   "source": [
    "## 4. With data_builder_dict Ready, We Can Use It for Federated Training"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b954742c",
   "metadata": {},
   "source": [
    "# Next, We Define an FLModel with the Torch Backend for Training"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b1219912",
   "metadata": {},
   "source": [
    "### Define the Model Structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "feea334e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.ml.nn.utils import BaseModule\n",
    "from torch import nn\n",
    "\n",
    "class ConvRGBNet(BaseModule):\n",
    "    def __init__(self, *args, **kwargs) -> None:\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.network = nn.Sequential(\n",
    "            nn.Conv2d(\n",
    "                in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1\n",
    "            ),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(16 * 45 * 45, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 5),\n",
    "        )\n",
    "\n",
    "    def forward(self, xb):\n",
    "        return self.network(xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "60414cb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.ml.nn import FLModel\n",
    "from secretflow.security.aggregation import SecureAggregator\n",
    "from torch import nn, optim\n",
    "from torchmetrics import Accuracy, Precision\n",
    "from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper\n",
    "from secretflow.ml.nn.utils import TorchModel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f368538f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.\n",
      "INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.\n",
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party alice.\n",
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party bob.\n"
     ]
    }
   ],
   "source": [
    "device_list = [alice, bob]\n",
    "aggregator = SecureAggregator(charlie, [alice, bob])\n",
    "# prepare model\n",
    "num_classes = 5\n",
    "\n",
    "input_shape = (180, 180, 3)\n",
    "# torch model\n",
    "loss_fn = nn.CrossEntropyLoss\n",
    "optim_fn = optim_wrapper(optim.Adam, lr=1e-3)\n",
    "model_def = TorchModel(\n",
    "    model_fn=ConvRGBNet,\n",
    "    loss_fn=loss_fn,\n",
    "    optim_fn=optim_fn,\n",
    "    metrics=[\n",
    "        metric_wrapper(\n",
    "            Accuracy, task=\"multiclass\", num_classes=num_classes, average='micro'\n",
    "        ),\n",
    "        metric_wrapper(\n",
    "            Precision, task=\"multiclass\", num_classes=num_classes, average='micro'\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "\n",
    "fed_model = FLModel(\n",
    "    device_list=device_list,\n",
    "    model=model_def,\n",
    "    aggregator=aggregator,\n",
    "    backend=\"torch\",\n",
    "    strategy=\"fed_avg_w\",\n",
    "    random_seed=1234,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9523afee",
   "metadata": {},
   "source": [
    "The dataset builder we constructed takes the path of the image dataset as its input, so the input data here must be set up as a `Dict`:\n",
    "```python\n",
    "data = {\n",
    "    alice: folder_path_of_alice,\n",
    "    bob: folder_path_of_bob\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "de4b659a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7ff4efc87af0>, 'x': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 5, 'verbose': 1, 'callbacks': None, 'validation_data': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {alice: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff600fb7ee0>, bob: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff6007148b0>}}\n",
      "100%|██████████| 30/30 [00:32<00:00,  1.08s/it, epoch: 1/5 -  multiclassaccuracy:0.3760416805744171  multiclassprecision:0.3760416805744171  val_multiclassaccuracy:0.0  val_multiclassprecision:0.0 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.27s/it, epoch: 2/5 -  multiclassaccuracy:0.5078125  multiclassprecision:0.5078125  val_multiclassaccuracy:0.1618257313966751  val_multiclassprecision:0.1618257313966751 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.28s/it, epoch: 3/5 -  multiclassaccuracy:0.51171875  multiclassprecision:0.51171875  val_multiclassaccuracy:0.004149377811700106  val_multiclassprecision:0.004149377811700106 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.27s/it, epoch: 4/5 -  multiclassaccuracy:0.5390625  multiclassprecision:0.5390625  val_multiclassaccuracy:0.02074688859283924  val_multiclassprecision:0.02074688859283924 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.28s/it, epoch: 5/5 -  multiclassaccuracy:0.5703125  multiclassprecision:0.5703125  val_multiclassaccuracy:0.016597511246800423  val_multiclassprecision:0.016597511246800423 ]\n"
     ]
    }
   ],
   "source": [
    "data = {\n",
    "    alice: path_to_flower_dataset,\n",
    "    bob: path_to_flower_dataset,\n",
    "}\n",
    "history = fed_model.fit(\n",
    "    data,\n",
    "    None,\n",
    "    validation_data=data,\n",
    "    epochs=5,\n",
    "    batch_size=32,\n",
    "    aggregate_freq=2,\n",
    "    sampler_method=\"batch\",\n",
    "    random_seed=1234,\n",
    "    dp_spent_step_freq=1,\n",
    "    dataset_builder=data_builder_dict\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
